
Implicit Gemm NVFP4 on Conv3D #886

Merged
jingyu-ml merged 18 commits into main from jingyux/implicit-gemm-nvfp4
Mar 14, 2026

Conversation

@jingyu-ml (Contributor) commented Feb 13, 2026

What does this PR do?

Type of change: new feature

Overview:

Experimental Conv3D implicit-GEMM CUDA kernel with optional NVFP4-style (E2M1 + FP8 E4M3 scale) fake quantization for activations.

It is intended for research/prototyping and quantization-accuracy experiments only, not production deployment.
The implementation runs as a JIT-compiled PyTorch extension, mirrors the conv3d output shape, and provides both quantized and non-quantized paths so their numerical behavior can be compared.

There is currently no real quantized production kernel integration in the formal ModelOpt export/compress/runtime stack; this path is kept in experimental/ for fake-quant accuracy validation and benchmarking.
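The kernel lowers Conv3D to one matrix multiply over an implicit im2col view, with K = Cin*kD*kH*kW (the same K used for the weight flattening shown under Usage). For readers new to the formulation, a minimal NumPy sketch of that equivalence follows; the shapes are tiny and hypothetical, and unlike this sketch the real kernel never materializes the im2col matrix:

```python
# Sketch only: demonstrates the implicit-GEMM equivalence the kernel exploits.
import numpy as np

def conv3d_reference(x, w):
    """Direct Conv3D (stride 1, no padding) via explicit loops."""
    N, Cin, D, H, W = x.shape
    Cout, _, kd, kh, kw = w.shape
    od, oh, ow = D - kd + 1, H - kh + 1, W - kw + 1
    out = np.zeros((N, Cout, od, oh, ow), dtype=x.dtype)
    for n in range(N):
        for co in range(Cout):
            for z in range(od):
                for y in range(oh):
                    for xx in range(ow):
                        patch = x[n, :, z:z+kd, y:y+kh, xx:xx+kw]
                        out[n, co, z, y, xx] = np.sum(patch * w[co])
    return out

def conv3d_im2col_gemm(x, w):
    """Same convolution as a single [M, K] @ [K, Cout] GEMM."""
    N, Cin, D, H, W = x.shape
    Cout, _, kd, kh, kw = w.shape
    od, oh, ow = D - kd + 1, H - kh + 1, W - kw + 1
    K = Cin * kd * kh * kw
    cols = np.empty((N * od * oh * ow, K), dtype=x.dtype)
    row = 0
    for n in range(N):
        for z in range(od):
            for y in range(oh):
                for xx in range(ow):
                    # Each output position contributes one K-length row.
                    cols[row] = x[n, :, z:z+kd, y:y+kh, xx:xx+kw].ravel()
                    row += 1
    out = cols @ w.reshape(Cout, K).T  # the "GEMM" in implicit GEMM
    return out.reshape(N, od, oh, ow, Cout).transpose(0, 4, 1, 2, 3)

rng = np.random.default_rng(0)
x = rng.standard_normal((1, 2, 4, 5, 6))
w = rng.standard_normal((3, 2, 3, 3, 3))
assert np.allclose(conv3d_reference(x, w), conv3d_im2col_gemm(x, w))
```

This is also why the FP4 block quantization below operates along K: each im2col row is exactly the K-dimension the A-tile is quantized over.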

Usage

import torch

from experimental.conv.implicit_gemm_cuda import conv3d_implicit_gemm_cuda
from modelopt.torch.quantization.tensor_quant import dynamic_block_quantize_op

x = torch.randn(1, 128, 21, 60, 106, device="cuda")
w = torch.randn(512, 128, 3, 3, 3, device="cuda")
block_size = 128

# Without FP4 activation quantization (drop-in-style Conv3D call)
out = conv3d_implicit_gemm_cuda(x, w, stride=(1, 1, 1), padding=(1, 1, 1))

# Optional FP4 block quantization of weights along the GEMM K dimension.
# The kernel's A-tile (activations) is quantized along K = Cin*kD*kH*kW,
# so weights must be flattened to [Cout, K] before quantizing to match.
Cout, Cin = w.shape[:2]
K = Cin * w.shape[2] * w.shape[3] * w.shape[4]
w_flat = w.reshape(Cout, K)
w_q_flat = dynamic_block_quantize_op(
    w_flat,
    block_size,
    w_flat.abs().max().unsqueeze(0),
    4,  # num_bits
    2,  # exponent_bits
    8,  # scale_num_bits
    4,  # scale_exponent_bits
)
w_q = w_q_flat.reshape_as(w)

# With FP4 activation fake quantization
out_q = conv3d_implicit_gemm_cuda(
    x,
    w_q,
    stride=(1, 1, 1),
    padding=(1, 1, 1),
    act_amax=x.abs().max().unsqueeze(0),
    quant_act=True,
    fp4_block_size=block_size,  # 128 or 256
)
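For intuition about what the fake-quant path does to values, here is a NumPy sketch of NVFP4-style fake quantization: snap each per-block-scaled value onto the E2M1 grid, then rescale. This is my own simplification for illustration only; in the actual kernel the per-block scale is itself round-tripped through FP8 E4M3, and rounding details may differ:

```python
# Simplified reference for E2M1 fake quantization with a per-block scale.
# Assumption (not the real kernel): the scale stays a plain float here.
import numpy as np

# Magnitudes representable in FP4 E2M1 (plus sign).
E2M1_GRID = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def fp4_fake_quant_ref(x, block_size):
    x = np.asarray(x, dtype=np.float64)
    assert x.size % block_size == 0
    blocks = x.reshape(-1, block_size)
    # Per-block scale maps the block amax onto the largest E2M1 value (6.0).
    scale = np.abs(blocks).max(axis=1, keepdims=True) / 6.0
    scale[scale == 0] = 1.0  # all-zero block: avoid division by zero
    scaled = blocks / scale
    # Snap each magnitude to the nearest grid point, preserving the sign.
    idx = np.abs(np.abs(scaled)[..., None] - E2M1_GRID).argmin(axis=-1)
    snapped = np.sign(scaled) * E2M1_GRID[idx]
    return (snapped * scale).reshape(x.shape)

# Values already on the grid (amax 6.0 => scale 1.0) pass through unchanged.
x_demo = np.array([0.0, 0.5, 1.5, 6.0, -3.0, 2.0, 0.5, 1.0])
assert np.allclose(fp4_fake_quant_ref(x_demo, block_size=8), x_demo)
```

Off-grid values are rounded to the nearest representable point, which is the quantization error the PR's accuracy experiments measure.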

Testing

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes/No
  • Did you write any new necessary tests?: Yes/No
  • Did you add or update any necessary documentation?: Yes/No
  • Did you update Changelog?: Yes/No

Additional Information

Summary by CodeRabbit

Release Notes

  • New Features

    • Added experimental Conv3D implementation with implicit GEMM acceleration and optional FP4 quantization support
    • Added benchmarking tool to compare 3D convolution performance across implementations
    • Enhanced quantization framework integration for Conv3D operations
  • Documentation

    • Added comprehensive guide for experimental Conv3D prototype, including supported scenarios, API reference, and current limitations

Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
@jingyu-ml jingyu-ml requested a review from a team as a code owner February 13, 2026 08:40
@jingyu-ml jingyu-ml marked this pull request as draft February 13, 2026 08:40
@copy-pr-bot bot commented Feb 13, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.


@coderabbitai bot commented Feb 13, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

This pull request introduces an experimental Conv3D implicit GEMM CUDA kernel implementation with optional FP4 quantization. It includes kernel code, Python bindings, a Python wrapper module, benchmarking tools, comprehensive tests, and integration with the existing quantization framework.

Changes

  • Experimental Conv3D Implicit GEMM Implementation (experimental/conv/implicit_gemm_kernel.cu, experimental/conv/implicit_gemm_binding.cpp, experimental/conv/implicit_gemm_cuda.py): New CUDA kernel using BF16 WMMA tensor cores for 3D convolution with optional blockwise FP4 quantization support; includes PyTorch C++/CUDA bindings and a high-level Python interface handling padding, reshaping, and dtype conversions. Exposes conv3d_implicit_gemm_cuda() and fp4_fake_quant().
  • Documentation & Benchmarking (experimental/conv/README.md, experimental/conv/bench_implicit_gemm.py): New documentation describing the Conv3D implicit GEMM prototype status, API surface, and usage patterns; latency benchmark script comparing cuDNN and implicit GEMM implementations across multiple tensor shapes and quantization modes.
  • Testing Infrastructure (experimental/conv/test_implicit_gemm.py): Comprehensive test suite covering non-quantized and FP4-quantized paths, including correctness validation against cuDNN, edge cases, determinism checks, and cross-validation with external implementations (Triton, modelopt).
  • Quantization Integration (modelopt/torch/quantization/nn/modules/quant_conv.py): Enhanced _QuantConv3d to optionally route through the experimental implicit GEMM path when NVFP4 quantization is enabled; adds condition checks and weight quantization helpers while preserving the fallback to the default cuDNN path.

Sequence Diagram

sequenceDiagram
    participant Python as Python Caller
    participant PyBind as PyTorch Bindings
    participant Kernel as CUDA Kernel
    participant Quant as Quantization Ops

    Python->>Python: Prepare input tensor, weights
    alt NVFP4 Quantization Enabled
        Python->>Quant: Get activation amax from input quantizer
        Quant-->>Python: Return amax value
        Python->>Quant: Quantize weights along K dimension
        Quant-->>Python: Return quantized weights
        Python->>PyBind: Call conv3d_implicit_gemm_cuda(x, w, act_amax, quant_act=True, fp4_block_size)
    else No Quantization
        Python->>PyBind: Call conv3d_implicit_gemm_cuda(x, w, act_amax=None, quant_act=False)
    end
    PyBind->>Kernel: Launch BF16 WMMA kernel with template config
    Kernel->>Kernel: Load A-tile (input) with optional FP4 quantization
    alt FP4 Quantization Active
        Kernel->>Kernel: Per-warp local max reduction
        Kernel->>Kernel: Compute quantization scales via FP8 round-trip
        Kernel->>Kernel: Quantize/dequantize A-tile values
    end
    Kernel->>Kernel: Load B-tile (weights)
    Kernel->>Kernel: Execute WMMA multiply-accumulate
    Kernel->>Kernel: Write output with L2 swizzle scheduling
    Kernel-->>PyBind: Return output tensor
    PyBind-->>Python: Return result
    Python->>Quant: Apply output quantizer if present
    Quant-->>Python: Return quantized output

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

🚥 Pre-merge checks | ✅ 4
✅ Passed checks (4 passed)
  • Description Check: ✅ Passed. Check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check: ✅ Passed. The title 'Implicit Gemm NVFP4 on Conv3D' directly and clearly describes the main change: adding an implicit GEMM implementation with NVFP4 quantization support for Conv3D operations.
  • Docstring Coverage: ✅ Passed. Docstring coverage is 87.32%, which meets the required 80.00% threshold.
  • Security Anti-Patterns: ✅ Passed. No security anti-patterns detected: torch.load, numpy.load, trust_remote_code, eval/exec, and nosec comments are all absent.


@jingyu-ml jingyu-ml self-assigned this Feb 13, 2026
@codecov bot commented Feb 13, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 70.07%. Comparing base (2d7d1ec) to head (9d425cc).
⚠️ Report is 2 commits behind head on main.

Additional details and impacted files
@@           Coverage Diff           @@
##             main     #886   +/-   ##
=======================================
  Coverage   70.07%   70.07%           
=======================================
  Files         221      221           
  Lines       25531    25531           
=======================================
  Hits        17892    17892           
  Misses       7639     7639           

☔ View full report in Codecov by Sentry.

Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
@jingyu-ml jingyu-ml marked this pull request as ready for review February 14, 2026 00:36
@coderabbitai bot left a comment


Actionable comments posted: 3

🤖 Fix all issues with AI agents
In `@experimental/conv/implicit_gemm_cuda.py`:
- Around line 549-560: Add an explicit validation for the fp4_block_size
parameter so unsupported values don't silently use the 256 branch: at the start
of the Python wrapper function that accepts fp4_block_size (the function that
ultimately selects between the two LAUNCH_WMMA_KERNEL branches), check that
fp4_block_size is either 128 or 256 and raise a ValueError with a clear message
if not (e.g., "fp4_block_size must be 128 or 256, got {fp4_block_size}"). Ensure
this validation is performed before any kernel-launch logic or passing
fp4_block_size into the CUDA launch path.
- Around line 669-673: The code currently combines quant_act and act_amax into
do_quant, silently disabling quantization when quant_act is True but act_amax is
None; change this by adding an explicit guard: if quant_act is True and act_amax
is None raise a ValueError (e.g., "act_amax is required when quant_act=True") so
callers are notified, otherwise keep the existing behavior of creating amax_t
when do_quant is True; update the block around the symbols quant_act, act_amax,
do_quant, and amax_t accordingly.

In `@experimental/conv/README.md`:
- Line 76: The README table uses the constant-style name `FP4_BLOCK_SIZE` which
doesn't match the Python function parameter `fp4_block_size`; update the table
entry to use `fp4_block_size` (or explicitly list both forms if you want to
document the env/constant separately) so it matches the function signature and
avoids confusion when calling the function with keyword arguments; locate the
table row that currently shows `FP4_BLOCK_SIZE` and replace it with
`fp4_block_size` (or add a parenthetical note like `fp4_block_size
(FP4_BLOCK_SIZE)` if documenting both).
🧹 Nitpick comments (5)
experimental/conv/implicit_gemm_cuda.py (5)

134-138: Stale template-parameter comments — BLOCK_N is 64, not 32.

The comments on lines 135 and 138 say BLOCK_N = 32 and WARPS_N = 2, but every actual instantiation (lines 554, 559) uses BLOCK_N=64, WARPS_N=4. Similarly, the comment on line 423 says 64 * 32 * 4 = 8192 bytes when the real footprint is 64 * 64 * 4 = 16384 bytes. The code is correct (it's fully parameterized), but these stale comments will mislead anyone reading the kernel.


255-319: Quantized A-tile load: FP4 block size is implicitly coupled to BLOCK_K.

The quantize-dequantize path computes one block_max per warp-row (i.e., over BLOCK_K elements via warp_reduce_max). This means the FP4 quantization block size is always exactly BLOCK_K, which only works correctly because BLOCK_K == fp4_block_size for both supported configs. If a future config changes BLOCK_K independently of fp4_block_size, quantization granularity will silently break. Worth a brief comment or a static_assert in the kernel:

static_assert(BLOCK_K == 128 || BLOCK_K == 256, "BLOCK_K must match fp4_block_size");

578-591: verbose=True will spam build logs on every first invocation.

For an experimental module this is fine during development, but consider gating it behind an environment variable or defaulting to False so downstream users don't get unexpected compiler output.


641-646: Input validation uses bare assert, which is stripped under python -O.

The assert statements on lines 643 and 646 will be silently removed when Python runs with optimizations enabled. For a CUDA kernel wrapper, invalid shapes reaching the kernel could cause silent corruption or hard crashes. Consider using explicit checks:

Proposed fix
-    assert x.ndim == 5 and w.ndim == 5
+    if x.ndim != 5 or w.ndim != 5:
+        raise ValueError(f"Expected 5D tensors, got x.ndim={x.ndim}, w.ndim={w.ndim}")
     n_batch, cin, d, h, w_in = x.shape
     cout, cin_w, kd, kh, kw = w.shape
-    assert cin_w == cin
+    if cin_w != cin:
+        raise ValueError(f"Input channels mismatch: x has {cin}, w has {cin_w}")

663-667: All inputs are cast to .float() (FP32) — potential unnecessary memory doubling.

If inputs are already FP32, the .float().contiguous() calls are cheap. But if inputs arrive as BF16 (common for the use-case described), this silently doubles memory. The docstring says "BF16 WMMA" but the kernel actually consumes FP32 global-memory inputs and converts to BF16 only in shared memory. This is worth a brief comment so users understand the kernel is not end-to-end BF16 in global memory.



@torch.no_grad()
def conv3d_implicit_gemm_cuda(
Contributor

I think we should add some unit tests for this to make sure it is doing what is expected.

@jingyu-ml (Contributor Author) Mar 5, 2026

Added, but I only put them under the experimental dir

import torch.nn.functional as F

# C++ header for function declarations
CPP_SOURCE = r"""
Contributor

I think it is better to put these into separate .cpp/.cu files and load them with torch.utils.cpp_extension.load; that will make the C++/CUDA source easier to read.

Contributor Author

Thank you! Moved the code to .cpp/.cu files.

@coderabbitai bot left a comment

🧹 Nitpick comments (2)
experimental/conv/implicit_gemm_binding.cpp (1)

1-20: Duplicate license header block.

The file contains two license header blocks (lines 1-16 and lines 18-19). Consider removing the duplicate.

Proposed fix
-// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-// SPDX-License-Identifier: Apache-2.0
-
 #include <torch/extension.h>
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@experimental/conv/implicit_gemm_binding.cpp` around lines 1 - 20, Remove the
duplicated SPDX/license header at the top of the file by keeping a single
canonical license block and deleting the second repeated block (the extra lines
containing "SPDX-FileCopyrightText" and "SPDX-License-Identifier" that appear
immediately after the first header). Ensure only one license header remains at
the top of implicit_gemm_binding.cpp and adjust surrounding whitespace so the
file begins with a single, correctly formatted header.
experimental/conv/implicit_gemm_kernel.cu (1)

1-32: Duplicate license header block.

Similar to the binding file, this contains duplicate license blocks (lines 1-16 and lines 18-31).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@experimental/conv/implicit_gemm_kernel.cu` around lines 1 - 32, The file
contains two repeated license header blocks starting with
"SPDX-FileCopyrightText" and "SPDX-License-Identifier" (the duplicate Apache-2.0
header); remove the redundant block so only a single license header remains at
the top of implicit_gemm_kernel.cu, keeping one copy of the SPDX lines and the
Apache-2.0 boilerplate and deleting the second copy.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Nitpick comments:
In `@experimental/conv/implicit_gemm_binding.cpp`:
- Around line 1-20: Remove the duplicated SPDX/license header at the top of the
file by keeping a single canonical license block and deleting the second
repeated block (the extra lines containing "SPDX-FileCopyrightText" and
"SPDX-License-Identifier" that appear immediately after the first header).
Ensure only one license header remains at the top of implicit_gemm_binding.cpp
and adjust surrounding whitespace so the file begins with a single, correctly
formatted header.

In `@experimental/conv/implicit_gemm_kernel.cu`:
- Around line 1-32: The file contains two repeated license header blocks
starting with "SPDX-FileCopyrightText" and "SPDX-License-Identifier" (the
duplicate Apache-2.0 header); remove the redundant block so only a single
license header remains at the top of implicit_gemm_kernel.cu, keeping one copy
of the SPDX lines and the Apache-2.0 boilerplate and deleting the second copy.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: e666c7a0-60f1-43a2-997e-847422fa1deb

📥 Commits

Reviewing files that changed from the base of the PR and between fcb4571 and 7ca8bd6.

📒 Files selected for processing (5)
  • experimental/conv/bench_implicit_gemm.py
  • experimental/conv/implicit_gemm_binding.cpp
  • experimental/conv/implicit_gemm_cuda.py
  • experimental/conv/implicit_gemm_kernel.cu
  • experimental/conv/test_implicit_gemm.py

Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
Copilot AI left a comment

Pull request overview

Adds an experimental Conv3D implicit-GEMM CUDA kernel (BF16 WMMA) with an optional fused FP4 (E2M1) activation fake-quant path intended for research/benchmarking, plus accompanying Python wrapper, tests, docs, and a small benchmark CLI.

Changes:

  • Introduce a JIT-compiled PyTorch CUDA extension implementing Conv3D via implicit GEMM with optional fused FP4 activation fake quantization.
  • Add standalone FP4 fake-quant kernel/binding for validation and cross-checking against ModelOpt/Triton paths.
  • Add experimental README, benchmark script, and a comprehensive (but experimental) test suite.

Reviewed changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 8 comments.

Show a summary per file

  • experimental/conv/implicit_gemm_kernel.cu: Implements the WMMA-based Conv3D implicit-GEMM kernel and a standalone FP4 fake-quant kernel.
  • experimental/conv/implicit_gemm_binding.cpp: Exposes the CUDA functions to Python via pybind.
  • experimental/conv/implicit_gemm_cuda.py: Python API plus JIT compilation/loading of the CUDA extension; reshaping/padding glue code.
  • experimental/conv/test_implicit_gemm.py: Correctness, determinism, and FP4 validation tests (incl. optional Triton/ModelOpt cross-checks).
  • experimental/conv/bench_implicit_gemm.py: CLI microbenchmark comparing cuDNN Conv3D vs implicit GEMM (quant/non-quant).
  • experimental/conv/README.md: Experimental usage notes, API description, and limitations.


@Edwardf0t1 (Contributor) left a comment

We could also add some results from applying NVFP4 implicit GEMM to the README.

Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
@jingyu-ml jingyu-ml requested a review from a team as a code owner March 9, 2026 22:58
@coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (1)
modelopt/torch/quantization/nn/modules/quant_conv.py (1)

125-162: Bypasses parent's quantize_weight() context manager pattern.

The parent class QuantLinearConvBase.forward uses a with self.quantize_weight(): context that sets _enable_weight_quantization. This custom forward bypasses that pattern entirely by directly calling _nvfp4_quantize_weight_along_k.

While this may be intentional for the NVFP4 path (since weight quantization is handled differently), consider adding a brief comment explaining why the context manager pattern is not used here, to prevent confusion for future maintainers.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/quantization/nn/modules/quant_conv.py` around lines 125 - 162,
The NVFP4 forward bypasses the parent QuantLinearConvBase.forward's with
self.quantize_weight() context and directly calls _nvfp4_quantize_weight_along_k
without setting _enable_weight_quantization; add a short comment in forward
(near the call to _nvfp4_quantize_weight_along_k) that explains this is
intentional because NVFP4 performs weight quantization differently (so the
quantize_weight() context and its _enable_weight_quantization flag must not be
used), and reference both quantize_weight() and _enable_weight_quantization to
help future maintainers understand why the parent pattern is omitted.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/quantization/nn/modules/quant_conv.py`:
- Around line 146-156: The implicit GEMM CUDA path doesn't support grouped
convolutions and can produce incorrect results for Conv3d with groups>1; update
_should_use_implicit_gemm() to return False when self.groups != 1 so the forward
path falls back to the standard cuDNN implementation; mention and guard against
the conv3d_implicit_gemm_cuda kernel (no groups parameter) and the weight
flattening in _nvfp4_quantize_weight_along_k (assumes non-grouped weight layout)
when adding this check.

---

Nitpick comments:
In `@modelopt/torch/quantization/nn/modules/quant_conv.py`:
- Around line 125-162: The NVFP4 forward bypasses the parent
QuantLinearConvBase.forward's with self.quantize_weight() context and directly
calls _nvfp4_quantize_weight_along_k without setting
_enable_weight_quantization; add a short comment in forward (near the call to
_nvfp4_quantize_weight_along_k) that explains this is intentional because NVFP4
performs weight quantization differently (so the quantize_weight() context and
its _enable_weight_quantization flag must not be used), and reference both
quantize_weight() and _enable_weight_quantization to help future maintainers
understand why the parent pattern is omitted.
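The groups guard flagged in the inline comment above can be sketched as follows. This is a hypothetical, trimmed-down version: the method name comes from the review text, and the real _should_use_implicit_gemm presumably checks additional conditions (NVFP4 config, CUDA availability, and so on):

```python
# Hypothetical sketch of the groups guard described in the review above.
# The real method checks more conditions than shown here.
class QuantConv3dSketch:
    def __init__(self, groups=1, nvfp4_enabled=True):
        self.groups = groups
        self.nvfp4_enabled = nvfp4_enabled

    def _should_use_implicit_gemm(self):
        # The implicit-GEMM kernel has no `groups` parameter, and the weight
        # flattening to [Cout, K] assumes a non-grouped layout, so grouped
        # convolutions must fall back to the standard cuDNN path.
        if self.groups != 1:
            return False
        return self.nvfp4_enabled

assert QuantConv3dSketch(groups=1)._should_use_implicit_gemm() is True
assert QuantConv3dSketch(groups=2)._should_use_implicit_gemm() is False
```

Returning False rather than raising keeps grouped convolutions working through the existing cuDNN fallback instead of turning them into hard errors.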

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 5a8d8522-f312-43e6-a4f1-bfb65ac9eebd

📥 Commits

Reviewing files that changed from the base of the PR and between 66278df and 6ba7802.

📒 Files selected for processing (3)
  • examples/diffusers/quantization/models_utils.py
  • examples/diffusers/quantization/utils.py
  • modelopt/torch/quantization/nn/modules/quant_conv.py

Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
@coderabbitai bot left a comment

🧹 Nitpick comments (3)
experimental/conv/test_implicit_gemm.py (1)

26-31: Consider adding a CUDA availability check in the fixture.

The fixture imports the CUDA module unconditionally. If tests are run on a system without CUDA, the import will fail with a non-obvious error. Consider wrapping with a CUDA availability check.

Proposed enhancement
 @pytest.fixture(scope="module")
 def cuda_conv3d():
     """Import and return the CUDA implicit GEMM conv3d function."""
+    if not torch.cuda.is_available():
+        pytest.skip("CUDA not available")
     from experimental.conv.implicit_gemm_cuda import conv3d_implicit_gemm_cuda

     return conv3d_implicit_gemm_cuda
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@experimental/conv/test_implicit_gemm.py` around lines 26 - 31, The fixture
cuda_conv3d currently imports conv3d_implicit_gemm_cuda unconditionally which
will error on machines without CUDA; modify the fixture to first check CUDA
availability (e.g., via pytest.importorskip or torch.cuda.is_available()) and
skip the tests if CUDA is not present, then import or return
conv3d_implicit_gemm_cuda only when CUDA is available so the test suite fails
gracefully on non-CUDA hosts.
experimental/conv/implicit_gemm_kernel.cu (1)

1-32: Duplicate license header.

The license header appears twice (lines 1-16 in block comment style and lines 18-31 in line comment style). Remove one of them.

Proposed fix: Remove the duplicate header
 /*
  * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  * SPDX-License-Identifier: Apache-2.0
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
  * You may obtain a copy of the License at
  *
  * http://www.apache.org/licenses/LICENSE-2.0
  *
  * Unless required by applicable law or agreed to in writing, software
  * distributed under the License is distributed on an "AS IS" BASIS,
  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */

-// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-// SPDX-License-Identifier: Apache-2.0
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
 // Conv3D Implicit GEMM with BF16 WMMA Tensor Cores and optional fused FP4 quantization.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@experimental/conv/implicit_gemm_kernel.cu` around lines 1 - 32, The file
contains a duplicated license header: a block comment header (/* ... */)
followed immediately by the same SPDX/license text in line-comment form (//
...). Remove one of the two headers so only a single canonical license header
remains at the top of implicit_gemm_kernel.cu (either keep the block comment or
the line-comment version), ensuring the SPDX identifiers and full license text
are preserved in the remaining header.
experimental/conv/implicit_gemm_cuda.py (1)

182-212: Add block_size validation to fp4_fake_quant for consistency with conv3d_implicit_gemm_cuda.

The fp4_fake_quant function accepts any block_size value without validation, while conv3d_implicit_gemm_cuda restricts fp4_block_size to {16, 32, 64, 128, 256}. Since the underlying CUDA kernel and documentation confirm only these five sizes are supported, add explicit validation to fail fast on invalid inputs and match the pattern used elsewhere in the module.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@experimental/conv/implicit_gemm_cuda.py` around lines 182 - 212,
fp4_fake_quant currently accepts any block_size but the CUDA kernel only
supports block sizes {16, 32, 64, 128, 256}; add a fast-fail validation at the
start of fp4_fake_quant (similar to conv3d_implicit_gemm_cuda's fp4_block_size
checks) that raises a clear error if block_size is not one of those five values,
referencing the allowed set in the message so callers know valid options.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Nitpick comments:
In `@experimental/conv/implicit_gemm_cuda.py`:
- Around line 182-212: fp4_fake_quant currently accepts any block_size but the
CUDA kernel only supports block sizes {16, 32, 64, 128, 256}; add a fast-fail
validation at the start of fp4_fake_quant (similar to
conv3d_implicit_gemm_cuda's fp4_block_size checks) that raises a clear error if
block_size is not one of those five values, referencing the allowed set in the
message so callers know valid options.

In `@experimental/conv/implicit_gemm_kernel.cu`:
- Around line 1-32: The file contains a duplicated license header: a block
comment header (/* ... */) followed immediately by the same SPDX/license text in
line-comment form (// ...). Remove one of the two headers so only a single
canonical license header remains at the top of implicit_gemm_kernel.cu (either
keep the block comment or the line-comment version), ensuring the SPDX
identifiers and full license text are preserved in the remaining header.

In `@experimental/conv/test_implicit_gemm.py`:
- Around line 26-31: The fixture cuda_conv3d currently imports
conv3d_implicit_gemm_cuda unconditionally which will error on machines without
CUDA; modify the fixture to first check CUDA availability (e.g., via
pytest.importorskip or torch.cuda.is_available()) and skip the tests if CUDA is
not present, then import or return conv3d_implicit_gemm_cuda only when CUDA is
available so the test suite fails gracefully on non-CUDA hosts.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 5cc45eab-7f8d-4cca-af07-61bb7df13470

📥 Commits

Reviewing files that changed from the base of the PR and between 6ba7802 and 52bb60d.

📒 Files selected for processing (4)
  • experimental/conv/README.md
  • experimental/conv/implicit_gemm_cuda.py
  • experimental/conv/implicit_gemm_kernel.cu
  • experimental/conv/test_implicit_gemm.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • experimental/conv/README.md

Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
@jingyu-ml jingyu-ml requested a review from Edwardf0t1 March 10, 2026 00:01
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
Comment on lines +118 to +119
try:
from experimental.conv.implicit_gemm_cuda import conv3d_implicit_gemm_cuda # noqa: F401
Collaborator

We cannot import from experimental since it's not part of the modelopt library. How do you plan to use this? Setting PYTHONPATH? Is this for local testing only?

Contributor Author

This shouldn't be in this MR. I accidentally included it. This change should be here:

https://github.com/NVIDIA/Model-Optimizer/tree/jingyux/implicit-gemm-nvfp4-e2e

deleting.

Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
@Edwardf0t1 (Contributor) left a comment

LGTM

@jingyu-ml jingyu-ml enabled auto-merge (squash) March 14, 2026 03:40
@jingyu-ml jingyu-ml merged commit 6f32d24 into main Mar 14, 2026
32 checks passed
@jingyu-ml jingyu-ml deleted the jingyux/implicit-gemm-nvfp4 branch March 14, 2026 04:01